import tensorflow as tf

from data_loader.unet_data_generator import DataGenerator, ValDataGenerator
from models.unet_model import UNET_Model_5
from trainers.unet_trainer import UNET_Trainer
from utils.config import process_config
from utils.dirs import create_dirs
from utils.utils import get_args
# Usage: python3 unet_main_train.py -c configs/train.json


def main():
    """Run U-Net training from the config file named on the command line.

    Steps, in order:
      1. Parse CLI arguments with ``get_args`` and load the config with
         ``process_config(args.config)``. On failure, report the error and
         exit with status 1.
      2. Create ``config.summary_dir`` and ``config.checkpoint_dir``.
      3. Build a new ``UNET_Model_5`` when ``config.resume == 0``, or load a
         saved Keras model from ``config.loadmodel_dir`` when
         ``config.resume == 1``.
      4. Build the training and validation data generators and run
         ``UNET_Trainer(...).train()``.

    Raises:
        ValueError: if ``config.resume`` is neither 0 nor 1.
    """
    import sys

    try:
        args = get_args()
        print(args.config)
        config = process_config(args.config)
    except Exception as err:
        # Catch only ordinary errors. A bare ``except:`` would also swallow
        # SystemExit (e.g. from argparse's --help / usage errors, if
        # get_args uses argparse) and KeyboardInterrupt, replacing them
        # with a misleading message and exit code.
        print(f"missing or invalid arguments: {err}", file=sys.stderr)
        sys.exit(1)

    create_dirs([config.summary_dir, config.checkpoint_dir])

    if config.resume == 0:
        model = UNET_Model_5(config)
    elif config.resume == 1:
        model = tf.keras.models.load_model(config.loadmodel_dir)
    else:
        # Without this branch `model` would be unbound and model.summary()
        # would fail with an unhelpful UnboundLocalError.
        raise ValueError(
            "config.resume must be 0 (build a new model) or 1 (load from "
            f"config.loadmodel_dir), got {config.resume!r}"
        )
    model.summary()

    # NOTE(review): the meaning of DataGenerator's second argument (1) is
    # defined by DataGenerator itself and is not visible here; confirm.
    data = DataGenerator(config, 1)
    valdata = ValDataGenerator(config)
    trainer = UNET_Trainer(model, data, valdata, config)
    trainer.train()


if __name__ == "__main__":  # entry point: training runs only when executed as a script
    main()
